-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for modified sum and difference loss from https://arxiv.org/abs/2208.11428 #71
base: main
Are you sure you want to change the base?
Conversation
sai-soum
commented
Feb 8, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this will be nice to add. Main point is that we should avoid creating a new class and extend the existing class to support this behavior if possible. If that doesn't workout, at least make the new class inherit from the main class to avoid any repetitive code.
@@ -604,3 +604,175 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): | |||
return loss | |||
elif self.output == "full": | |||
return loss, sum_loss, diff_loss | |||
|
|||
class ModifiedSumAndDifferenceSTFTLoss(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of creating a new class for this we should add parameters to SumAndDifferenceSTFTLoss
in order to support this behavior. It seems like the major modification is that application of the pre-emphasis filter.
l1log_sum_loss = self.l1logstft(input_sum_mag, target_sum_mag) | ||
l1log_diff_loss = self.l1logstft(input_diff_mag, target_diff_mag) | ||
|
||
if self.loss_type == 'SClogL1': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the other difference in the distance measure. This should be able to be supported by the main class. However, if it seems easier, we could consider adding a new ModifiedSumAndDifferenceLoss
class but where it inherits from the main class so that we don't get all this repeated code (e.g. stft
, etc.)